{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Machine Translation English-German Example Using SageMaker Seq2Seq\n",
    "\n",
    "1. [Introduction](#Introduction)\n",
    "2. [Setup](#Setup)\n",
    "3. [Download dataset and preprocess](#Download-dataset-and-preprocess)\n",
    "4. [Training the Machine Translation model](#Training-the-Machine-Translation-model)\n",
    "5. [Inference](#Inference)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "Welcome to our Machine Translation end-to-end example! In this demo, we will train an English-German translation model and test its predictions on a few examples.\n",
    "\n",
    "The SageMaker Seq2Seq algorithm is built on top of [Sockeye](https://github.com/awslabs/sockeye), a sequence-to-sequence framework for Neural Machine Translation based on MXNet. SageMaker Seq2Seq implements state-of-the-art encoder-decoder architectures, which can also be used for tasks such as Abstractive Summarization in addition to Machine Translation.\n",
    "\n",
    "To get started, we need to set up the environment with a few prerequisite steps, for permissions, configurations, and so on."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Let's start by specifying:\n",
    "- The S3 bucket and prefix that you want to use for training and model data. **This should be within the same region as the Notebook Instance, training, and hosting.**\n",
    "- The IAM role ARN used to give training and hosting access to your data. See the documentation for how to create these. Note: if more than one role is required for notebook instances, training, and/or hosting, replace the `get_execution_role()` call in the cells below with the appropriate full IAM role ARN string(s)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "isConfigCell": true
   },
   "outputs": [],
   "source": [
    "# S3 bucket and prefix\n",
    "bucket = '<your_s3_bucket_name_here>'\n",
    "prefix = 'sagemaker/DEMO-seq2seq'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "import re\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll import the Python libraries we'll need for the remainder of the exercise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from time import gmtime, strftime\n",
    "import time\n",
    "import numpy as np\n",
    "import os\n",
    "import json\n",
    "\n",
    "# For plotting attention matrix later on\n",
    "import matplotlib\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download dataset and preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will train an English-to-German translation model on a dataset from the\n",
    "[Conference on Machine Translation (WMT) 2017](http://www.statmt.org/wmt17/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "wget http://data.statmt.org/wmt17/translation-task/preprocessed/de-en/corpus.tc.de.gz & \\\n",
    "wget http://data.statmt.org/wmt17/translation-task/preprocessed/de-en/corpus.tc.en.gz & wait\n",
    "gunzip corpus.tc.de.gz & \\\n",
    "gunzip corpus.tc.en.gz & wait\n",
    "mkdir -p validation\n",
    "curl http://data.statmt.org/wmt17/translation-task/preprocessed/de-en/dev.tgz | tar xvzf - -C validation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that it is common practice to split words into subwords using Byte Pair Encoding (BPE). Refer to [this tutorial](https://github.com/awslabs/sockeye/tree/master/tutorials/wmt) if you are interested in performing BPE."
   ]
  },
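  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal, illustrative sketch of how BPE merge rules can be learned: repeatedly find the most frequent adjacent symbol pair and fuse it into a new symbol. The function name `learn_bpe_merges` and the toy corpus are assumptions made for illustration; this is not the implementation used by Sockeye or the linked tutorial, and it omits details such as end-of-word markers.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def learn_bpe_merges(words, num_merges):\n",
    "    # Represent each word as a tuple of symbols, starting from single characters.\n",
    "    vocab = Counter(tuple(word) for word in words)\n",
    "    merges = []\n",
    "    for _ in range(num_merges):\n",
    "        # Count how often each adjacent symbol pair occurs across the corpus.\n",
    "        pair_counts = Counter()\n",
    "        for symbols, freq in vocab.items():\n",
    "            for pair in zip(symbols, symbols[1:]):\n",
    "                pair_counts[pair] += freq\n",
    "        if not pair_counts:\n",
    "            break\n",
    "        best = max(pair_counts, key=pair_counts.get)\n",
    "        merges.append(best)\n",
    "        # Rewrite every word, fusing each occurrence of the best pair into one symbol.\n",
    "        new_vocab = Counter()\n",
    "        for symbols, freq in vocab.items():\n",
    "            merged, i = [], 0\n",
    "            while i < len(symbols):\n",
    "                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:\n",
    "                    merged.append(symbols[i] + symbols[i + 1])\n",
    "                    i += 2\n",
    "                else:\n",
    "                    merged.append(symbols[i])\n",
    "                    i += 1\n",
    "            new_vocab[tuple(merged)] += freq\n",
    "        vocab = new_vocab\n",
    "    return merges\n",
    "\n",
    "\n",
    "# Toy corpus (hypothetical): learn two merge rules.\n",
    "print(learn_bpe_merges(['low', 'low', 'lower'], 2))\n",
    "```\n",
    "\n",
    "On this toy corpus the learned merges build up the stem shared by `low` and `lower`, which is how BPE lets a rare word be expressed through frequent subword units."
   ]
  },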
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since training on the whole dataset might take several hours/days, for this demo, let us train on the **first 10,000 lines only**. Don't run the next cell if you want to train on the complete dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!head -n 10000 corpus.tc.en > corpus.tc.en.small\n",
    "!head -n 10000 corpus.tc.de > corpus.tc.de.small"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's use the preprocessing script `create_vocab_proto.py` (provided with this notebook) to create vocabulary mappings (strings to integers) and convert these files to x-recordio-protobuf as required for training by SageMaker Seq2Seq.  \n",
    "Uncomment the command in the cell below and run it to check the arguments this script expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "# python3 create_vocab_proto.py -h"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below does the preprocessing. If you are using the complete dataset, the script might take around 10-15 min on an m4.xlarge notebook instance. Remove \".small\" from the file names for training on full datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "%%bash\n",
    "python3 create_vocab_proto.py \\\n",
    "        --train-source corpus.tc.en.small \\\n",
    "        --train-target corpus.tc.de.small \\\n",
    "        --val-source validation/newstest2014.tc.en \\\n",
    "        --val-target validation/newstest2014.tc.de"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The script outputs four files:\n",
    "- train.rec : Contains source and target sentences for training in protobuf format\n",
    "- val.rec : Contains source and target sentences for validation in protobuf format\n",
    "- vocab.src.json : Vocabulary mapping (string to int) for source language (English in this example)\n",
    "- vocab.trg.json : Vocabulary mapping (string to int) for target language (German in this example)\n",
    "\n",
    "Let's upload the preprocessed dataset and vocabularies to S3."
   ]
  },
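  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The vocabulary files map each token to an integer id. The sketch below shows one way such a mapping could be built and applied; it is not the code in `create_vocab_proto.py`. The function name `build_vocab`, the `max_size` default, and the special tokens other than `<unk>` (and their ordering) are assumptions for illustration. This notebook itself only relies on `<unk>` being mapped to 1.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "# Assumed special tokens; only '<unk>' -> 1 is relied on elsewhere in this notebook.\n",
    "SPECIAL_TOKENS = ['<pad>', '<unk>', '<s>', '</s>']\n",
    "\n",
    "\n",
    "def build_vocab(sentences, max_size=50000):\n",
    "    # Count whitespace-separated tokens over all sentences.\n",
    "    counts = Counter(token for sentence in sentences for token in sentence.split())\n",
    "    # Reserve the lowest ids for the special tokens, then add tokens by frequency.\n",
    "    vocab = {token: idx for idx, token in enumerate(SPECIAL_TOKENS)}\n",
    "    for token, _ in counts.most_common(max_size - len(SPECIAL_TOKENS)):\n",
    "        vocab.setdefault(token, len(vocab))\n",
    "    return vocab\n",
    "\n",
    "\n",
    "toy_vocab = build_vocab(['a b a', 'a c'])\n",
    "# Out-of-vocabulary tokens fall back to the id of '<unk>'.\n",
    "toy_ids = [toy_vocab.get(token, toy_vocab['<unk>']) for token in 'a b z'.split()]\n",
    "print(toy_ids)\n",
    "```\n",
    "\n",
    "Reserving fixed ids for special tokens keeps the out-of-vocabulary lookup stable regardless of the training data."
   ]
  },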
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
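   },
   "outputs": [],
   "source": [
    "def upload_to_s3(bucket, prefix, channel, file):\n",
    "    # Upload a local file to s3://<bucket>/<prefix>/<channel>/<file>\n",
    "    s3 = boto3.resource('s3')\n",
    "    key = prefix + \"/\" + channel + '/' + file\n",
    "    # Open the file in a context manager so the handle is closed after the upload\n",
    "    with open(file, \"rb\") as data:\n",
    "        s3.Bucket(bucket).put_object(Key=key, Body=data)\n",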
    "\n",
    "upload_to_s3(bucket, prefix, 'train', 'train.rec')\n",
    "upload_to_s3(bucket, prefix, 'validation', 'val.rec')\n",
    "upload_to_s3(bucket, prefix, 'vocab', 'vocab.src.json')\n",
    "upload_to_s3(bucket, prefix, 'vocab', 'vocab.trg.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "region_name = boto3.Session().region_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "container = get_image_uri(region_name, 'seq2seq')\n",
    "\n",
    "print('Using SageMaker Seq2Seq container: {} ({})'.format(container, region_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the Machine Translation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
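   },
   "outputs": [],
   "source": [
    "# Include minutes and seconds so that re-running this cell does not reuse an existing job name\n",
    "job_name = 'DEMO-seq2seq-en-de-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",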
    "print(\"Training job\", job_name)\n",
    "\n",
    "create_training_params = \\\n",
    "{\n",
    "    \"AlgorithmSpecification\": {\n",
    "        \"TrainingImage\": container,\n",
    "        \"TrainingInputMode\": \"File\"\n",
    "    },\n",
    "    \"RoleArn\": role,\n",
    "    \"OutputDataConfig\": {\n",
    "        \"S3OutputPath\": \"s3://{}/{}/\".format(bucket, prefix)\n",
    "    },\n",
    "    \"ResourceConfig\": {\n",
    "        # Seq2Seq does not support multiple machines. Currently, it only supports single machine, multiple GPUs\n",
    "        \"InstanceCount\": 1,\n",
    "        \"InstanceType\": \"ml.p2.xlarge\", # We suggest one of [\"ml.p2.16xlarge\", \"ml.p2.8xlarge\", \"ml.p2.xlarge\"]\n",
    "        \"VolumeSizeInGB\": 50\n",
    "    },\n",
    "    \"TrainingJobName\": job_name,\n",
    "    \"HyperParameters\": {\n",
    "        # Please refer to the documentation for complete list of parameters\n",
    "        \"max_seq_len_source\": \"60\",\n",
    "        \"max_seq_len_target\": \"60\",\n",
    "        \"optimized_metric\": \"bleu\",\n",
    "        \"batch_size\": \"64\", # Please use a larger batch size (256 or 512) if using ml.p2.8xlarge or ml.p2.16xlarge\n",
    "        \"checkpoint_frequency_num_batches\": \"1000\",\n",
    "        \"rnn_num_hidden\": \"512\",\n",
    "        \"num_layers_encoder\": \"1\",\n",
    "        \"num_layers_decoder\": \"1\",\n",
    "        \"num_embed_source\": \"512\",\n",
    "        \"num_embed_target\": \"512\",\n",
    "        \"checkpoint_threshold\": \"3\",\n",
    "        \"max_num_batches\": \"2100\"\n",
    "        # Training will stop after 2100 iterations/batches.\n",
    "        # This is just for demo purposes. Remove the above parameter if you want a better model.\n",
    "    },\n",
    "    \"StoppingCondition\": {\n",
    "        \"MaxRuntimeInSeconds\": 48 * 3600\n",
    "    },\n",
    "    \"InputDataConfig\": [\n",
    "        {\n",
    "            \"ChannelName\": \"train\",\n",
    "            \"DataSource\": {\n",
    "                \"S3DataSource\": {\n",
    "                    \"S3DataType\": \"S3Prefix\",\n",
    "                    \"S3Uri\": \"s3://{}/{}/train/\".format(bucket, prefix),\n",
    "                    \"S3DataDistributionType\": \"FullyReplicated\"\n",
    "                }\n",
    "            },\n",
    "        },\n",
    "        {\n",
    "            \"ChannelName\": \"vocab\",\n",
    "            \"DataSource\": {\n",
    "                \"S3DataSource\": {\n",
    "                    \"S3DataType\": \"S3Prefix\",\n",
    "                    \"S3Uri\": \"s3://{}/{}/vocab/\".format(bucket, prefix),\n",
    "                    \"S3DataDistributionType\": \"FullyReplicated\"\n",
    "                }\n",
    "            },\n",
    "        },\n",
    "        {\n",
    "            \"ChannelName\": \"validation\",\n",
    "            \"DataSource\": {\n",
    "                \"S3DataSource\": {\n",
    "                    \"S3DataType\": \"S3Prefix\",\n",
    "                    \"S3Uri\": \"s3://{}/{}/validation/\".format(bucket, prefix),\n",
    "                    \"S3DataDistributionType\": \"FullyReplicated\"\n",
    "                }\n",
    "            },\n",
    "        }\n",
    "    ]\n",
    "}\n",
    "\n",
    "sagemaker_client = boto3.Session().client(service_name='sagemaker')\n",
    "sagemaker_client.create_training_job(**create_training_params)\n",
    "\n",
    "status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
    "print(status)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
    "print(status)\n",
    "# if the job failed, determine why\n",
    "if status == 'Failed':\n",
    "    message = sagemaker_client.describe_training_job(TrainingJobName=job_name)['FailureReason']\n",
    "    print('Training failed with the following error: {}'.format(message))\n",
    "    raise Exception('Training job failed')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Now wait for the training job to complete and proceed to the next step after you see model artifacts in your S3 bucket."
   ]
  },
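  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rather than re-running the status cell by hand, the wait can be automated by polling `describe_training_job` until the job reaches a terminal status. The helper below is a sketch under assumptions: the name `wait_for_training_job` and the `poll_seconds` default are made up for illustration, and the client is passed in so the function does not depend on any global state.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def wait_for_training_job(client, job_name, poll_seconds=60):\n",
    "    # Poll describe_training_job until the job reaches a terminal status.\n",
    "    terminal = {'Completed', 'Failed', 'Stopped'}\n",
    "    while True:\n",
    "        description = client.describe_training_job(TrainingJobName=job_name)\n",
    "        status = description['TrainingJobStatus']\n",
    "        if status in terminal:\n",
    "            # FailureReason is only present when the job has failed.\n",
    "            return status, description.get('FailureReason')\n",
    "        time.sleep(poll_seconds)\n",
    "```\n",
    "\n",
    "In this notebook it could be called as `status, reason = wait_for_training_job(sagemaker_client, job_name)`. The boto3 SageMaker client also provides a `training_job_completed_or_stopped` waiter that serves a similar purpose."
   ]
  },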
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can jump to [Use a pretrained model](#Use-a-pretrained-model) as training might take some time."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "A trained model does nothing on its own. We now want to use the model to perform inference. For this example, that means translating sentence(s) from English to German.\n",
    "This section involves several steps:\n",
    "- Create model - Create a model using the artifact (model.tar.gz) produced by training\n",
    "- Create Endpoint Configuration - Create a configuration defining an endpoint, using the above model\n",
    "- Create Endpoint - Use the configuration to create an inference endpoint.\n",
    "- Perform Inference - Perform inference on some input data using the endpoint.\n",
    "\n",
    "### Create model\n",
    "We now create a SageMaker Model from the training output. Using the model, we can then create an Endpoint Configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "use_pretrained_model = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use a pretrained model\n",
    "#### Please uncomment and run the cell below if you want to use a pretrained model, as training might take several hours/days to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#use_pretrained_model = True\n",
    "#model_name = \"DEMO-pretrained-en-de-model\"\n",
    "#!curl https://s3-us-west-2.amazonaws.com/seq2seq-data/model.tar.gz > model.tar.gz\n",
    "#!curl https://s3-us-west-2.amazonaws.com/seq2seq-data/vocab.src.json > vocab.src.json\n",
    "#!curl https://s3-us-west-2.amazonaws.com/seq2seq-data/vocab.trg.json > vocab.trg.json\n",
    "#upload_to_s3(bucket, prefix, 'pretrained_model', 'model.tar.gz')\n",
    "#model_data = \"s3://{}/{}/pretrained_model/model.tar.gz\".format(bucket, prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "sage = boto3.client('sagemaker')\n",
    "\n",
    "if not use_pretrained_model:\n",
    "    info = sage.describe_training_job(TrainingJobName=job_name)\n",
    "    model_name=job_name\n",
    "    model_data = info['ModelArtifacts']['S3ModelArtifacts']\n",
    "\n",
    "print(model_name)\n",
    "print(model_data)\n",
    "\n",
    "primary_container = {\n",
    "    'Image': container,\n",
    "    'ModelDataUrl': model_data\n",
    "}\n",
    "\n",
    "create_model_response = sage.create_model(\n",
    "    ModelName = model_name,\n",
    "    ExecutionRoleArn = role,\n",
    "    PrimaryContainer = primary_container)\n",
    "\n",
    "print(create_model_response['ModelArn'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create endpoint configuration\n",
    "Use the model to create an endpoint configuration. The endpoint configuration also contains information about the type and number of EC2 instances to use when hosting the model.\n",
    "\n",
    "Since SageMaker Seq2Seq is based on neural networks, we could use an ml.p2.xlarge (GPU) instance, but for this example we will use a free-tier-eligible ml.m4.xlarge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from time import gmtime, strftime\n",
    "\n",
    "endpoint_config_name = 'DEMO-Seq2SeqEndpointConfig-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "print(endpoint_config_name)\n",
    "create_endpoint_config_response = sage.create_endpoint_config(\n",
    "    EndpointConfigName = endpoint_config_name,\n",
    "    ProductionVariants=[{\n",
    "        'InstanceType':'ml.m4.xlarge',\n",
    "        'InitialInstanceCount':1,\n",
    "        'ModelName':model_name,\n",
    "        'VariantName':'AllTraffic'}])\n",
    "\n",
    "print(\"Endpoint Config Arn: \" + create_endpoint_config_response['EndpointConfigArn'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create endpoint\n",
    "Lastly, we create the endpoint that serves the model by specifying the name and configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This takes 10-15 minutes to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "import time\n",
    "\n",
    "endpoint_name = 'DEMO-Seq2SeqEndpoint-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "print(endpoint_name)\n",
    "create_endpoint_response = sage.create_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    EndpointConfigName=endpoint_config_name)\n",
    "print(create_endpoint_response['EndpointArn'])\n",
    "\n",
    "resp = sage.describe_endpoint(EndpointName=endpoint_name)\n",
    "status = resp['EndpointStatus']\n",
    "print(\"Status: \" + status)\n",
    "\n",
    "# wait until the status has changed\n",
    "sage.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)\n",
    "\n",
    "# print the status of the endpoint\n",
    "endpoint_response = sage.describe_endpoint(EndpointName=endpoint_name)\n",
    "status = endpoint_response['EndpointStatus']\n",
    "print('Endpoint creation ended with EndpointStatus = {}'.format(status))\n",
    "\n",
    "if status != 'InService':\n",
    "    raise Exception('Endpoint creation failed.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you see the message,\n",
    "> Endpoint creation ended with EndpointStatus = InService\n",
    "\n",
    "then congratulations! You now have a functioning inference endpoint. You can confirm the endpoint configuration and status by navigating to the \"Endpoints\" tab in the AWS SageMaker console.  \n",
    "\n",
    "We will finally create a runtime object from which we can invoke the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "runtime = boto3.client(service_name='runtime.sagemaker') "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using JSON format for inference (Suggested for a single or small number of data instances)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note that you don't have to convert strings to integers using the vocabulary mapping for inference in JSON mode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sentences = [\"you are so good !\",\n",
    "             \"can you drive a car ?\",\n",
    "             \"i want to watch a movie .\"\n",
    "            ]\n",
    "\n",
    "payload = {\"instances\" : []}\n",
    "for sent in sentences:\n",
    "    payload[\"instances\"].append({\"data\" : sent})\n",
    "\n",
    "response = runtime.invoke_endpoint(EndpointName=endpoint_name, \n",
    "                                   ContentType='application/json', \n",
    "                                   Body=json.dumps(payload))\n",
    "\n",
    "response = response[\"Body\"].read().decode(\"utf-8\")\n",
    "response = json.loads(response)\n",
    "print(response)"
   ]
  },
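  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The request and response handling above can be wrapped in two small helpers. This is a sketch: the names `build_payload` and `extract_translations` are hypothetical, and the response shape (a `predictions` list whose items carry a `target` string) is taken from how this notebook reads the response, not from a separate specification.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def build_payload(sentences):\n",
    "    # One instance per sentence, each holding the raw text under the 'data' key.\n",
    "    return json.dumps({'instances': [{'data': sentence} for sentence in sentences]})\n",
    "\n",
    "\n",
    "def extract_translations(response_body):\n",
    "    # response_body is the decoded JSON string returned by the endpoint.\n",
    "    predictions = json.loads(response_body)['predictions']\n",
    "    return [prediction['target'] for prediction in predictions]\n",
    "```\n",
    "\n",
    "With these helpers, `build_payload(sentences)` would be passed as `Body` to `runtime.invoke_endpoint`, and the decoded response body would be passed to `extract_translations`."
   ]
  },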
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Retrieving the Attention Matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Passing `\"attention_matrix\":\"true\"` in `configuration` of the data instance will return the attention matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sentence = 'can you drive a car ?'\n",
    "\n",
    "payload = {\"instances\" : [{\n",
    "                            \"data\" : sentence,\n",
    "                            \"configuration\" : {\"attention_matrix\":\"true\"}\n",
    "                          }\n",
    "                         ]}\n",
    "\n",
    "response = runtime.invoke_endpoint(EndpointName=endpoint_name, \n",
    "                                   ContentType='application/json', \n",
    "                                   Body=json.dumps(payload))\n",
    "\n",
    "response = response[\"Body\"].read().decode(\"utf-8\")\n",
    "response = json.loads(response)['predictions'][0]\n",
    "\n",
    "source = sentence\n",
    "target = response[\"target\"]\n",
    "attention_matrix = np.array(response[\"matrix\"])\n",
    "\n",
    "print(\"Source: %s \\nTarget: %s\" % (source, target))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define a function for plotting the attention matrix\n",
    "def plot_matrix(attention_matrix, target, source):\n",
    "    source_tokens = source.split()\n",
    "    target_tokens = target.split()\n",
    "    assert attention_matrix.shape[0] == len(target_tokens)\n",
    "    plt.imshow(attention_matrix.transpose(), interpolation=\"nearest\", cmap=\"Greys\")\n",
    "    plt.xlabel(\"target\")\n",
    "    plt.ylabel(\"source\")\n",
    "    plt.gca().set_xticks([i for i in range(0, len(target_tokens))])\n",
    "    plt.gca().set_yticks([i for i in range(0, len(source_tokens))])\n",
    "    plt.gca().set_xticklabels(target_tokens)\n",
    "    plt.gca().set_yticklabels(source_tokens)\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "plot_matrix(attention_matrix, target, source)"
   ]
  },
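  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides plotting, the attention matrix can be read numerically: for each target token, pick the source token that received the largest attention weight. The sketch below assumes that rows correspond to target tokens and columns to source tokens, as the assertion in `plot_matrix` suggests; the function name `hard_alignments` is hypothetical, and the result is only a rough indication of word alignment.\n",
    "\n",
    "```python\n",
    "def hard_alignments(attention_matrix, target, source):\n",
    "    source_tokens = source.split()\n",
    "    target_tokens = target.split()\n",
    "    alignments = []\n",
    "    for target_token, weights in zip(target_tokens, attention_matrix):\n",
    "        # Only consider as many columns as there are source tokens.\n",
    "        weights = list(weights)[:len(source_tokens)]\n",
    "        # Index of the largest attention weight in this row.\n",
    "        best = max(range(len(weights)), key=weights.__getitem__)\n",
    "        alignments.append((target_token, source_tokens[best]))\n",
    "    return alignments\n",
    "```\n",
    "\n",
    "For the example above it could be called as `hard_alignments(attention_matrix, target, source)`; it accepts either a NumPy array or a nested list."
   ]
  },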
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Using Protobuf format for inference (Suggested for efficient bulk inference)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, read the vocabulary mappings, since this mode of inference accepts lists of integers and returns lists of integers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import io\n",
    "import tempfile\n",
    "from record_pb2 import Record\n",
    "from create_vocab_proto import vocab_from_json, reverse_vocab, write_recordio, list_to_record_bytes, read_next\n",
    "\n",
    "source = vocab_from_json(\"vocab.src.json\")\n",
    "target = vocab_from_json(\"vocab.trg.json\")\n",
    "\n",
    "source_rev = reverse_vocab(source)\n",
    "target_rev = reverse_vocab(target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sentences = [\"this is so cool\",\n",
    "            \"i am having dinner .\",\n",
    "            \"i am sitting in an aeroplane .\",\n",
    "            \"come let us go for a long drive .\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Converting the string to integers, followed by protobuf encoding:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Convert strings to integers using source vocab mapping. Out-of-vocabulary strings are mapped to 1 - the mapping for <unk>\n",
    "sentences = [[source.get(token, 1) for token in sentence.split()] for sentence in sentences]\n",
    "f = io.BytesIO()\n",
    "for sentence in sentences:\n",
    "    record = list_to_record_bytes(sentence, [])\n",
    "    write_recordio(f, record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "response = runtime.invoke_endpoint(EndpointName=endpoint_name, \n",
    "                                   ContentType='application/x-recordio-protobuf', \n",
    "                                   Body=f.getvalue())\n",
    "\n",
    "response = response[\"Body\"].read()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, parse the protobuf response and convert the lists of integers back to strings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def _parse_proto_response(received_bytes):\n",
    "    # Write the raw response to a temporary file so that read_next can iterate over its records;\n",
    "    # the context manager removes the temporary file when parsing is done\n",
    "    target_sentences = []\n",
    "    with tempfile.NamedTemporaryFile() as output_file:\n",
    "        output_file.write(received_bytes)\n",
    "        output_file.flush()\n",
    "        with open(output_file.name, 'rb') as datum:\n",
    "            while True:\n",
    "                next_record = read_next(datum)\n",
    "                if not next_record:\n",
    "                    break\n",
    "                rec = Record()\n",
    "                rec.ParseFromString(next_record)\n",
    "                # Each record holds the predicted target token ids\n",
    "                target = list(rec.features[\"target\"].int32_tensor.values)\n",
    "                target_sentences.append(target)\n",
    "    return target_sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "targets = _parse_proto_response(response)\n",
    "resp = [\" \".join([target_rev.get(token, \"<unk>\") for token in sentence]) for\n",
    "                               sentence in targets]\n",
    "print(resp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stop / Close the Endpoint (Optional)\n",
    "\n",
    "Finally, we should delete the endpoint before we close the notebook, since a running endpoint continues to incur charges until it is deleted."
   ]
  },
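  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below deletes only the endpoint; the endpoint configuration and the model created earlier remain in the account. A sketch of a fuller cleanup is shown here, assuming you no longer need either of them. The helper name `delete_endpoint_resources` is hypothetical, and the client is passed in explicitly.\n",
    "\n",
    "```python\n",
    "def delete_endpoint_resources(client, endpoint_name, endpoint_config_name, model_name):\n",
    "    # Delete the endpoint first, then the configuration and model it referenced.\n",
    "    client.delete_endpoint(EndpointName=endpoint_name)\n",
    "    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n",
    "    client.delete_model(ModelName=model_name)\n",
    "```\n",
    "\n",
    "In this notebook it could be called as `delete_endpoint_resources(sage, endpoint_name, endpoint_config_name, model_name)` in place of the cell below. Model artifacts and data uploaded to S3 are not affected and would need to be removed separately."
   ]
  },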
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sage.delete_endpoint(EndpointName=endpoint_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
